{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "5b12027d-8164-4e7e-8a64-226d6f932419",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.utilities import SQLDatabase\n",
    "from langchain.llms import OpenAI\n",
    "from langchain_experimental.sql import SQLDatabaseChain"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "9bfe39d0-9284-4065-9c3a-a1e25f4ecedc",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ['OPENAI_API_KEY'] = ''"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "0b6aeecb-94ee-441b-ac18-4f5572834acd",
   "metadata": {},
   "outputs": [],
   "source": [
    "db = SQLDatabase.from_uri(\"mysql+pymysql://lang:lang@127.0.0.1:3306/Chinook?charset=utf8\")\n",
    "llm = OpenAI(temperature=0, verbose=True)\n",
    "db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "53aa0147-8d88-4aeb-b41b-93bfb910b67e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\u001b[1m> Entering new SQLDatabaseChain chain...\u001b[0m\n",
      "How many employees are there?\n",
      "SQLQuery:\u001b[32;1m\u001b[1;3mSELECT COUNT(*) AS `NumberOfEmployees` FROM `Employee`;\u001b[0m\n",
      "SQLResult: \u001b[33;1m\u001b[1;3m[(8,)]\u001b[0m\n",
      "Answer:\u001b[32;1m\u001b[1;3mThere are 8 employees.\u001b[0m\n",
      "\u001b[1m> Finished chain.\u001b[0m\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'There are 8 employees.'"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "db_chain.run(\"How many employees are there?\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "b5060c61-8411-4520-9f01-305c2f5f4a0e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "SELECT COUNT(*) FROM Employee\n"
     ]
    }
   ],
   "source": [
    "from langchain.chat_models import ChatOpenAI\n",
    "from langchain.chains import create_sql_query_chain\n",
    "chain = create_sql_query_chain(ChatOpenAI(temperature=0), db)\n",
    "response = chain.invoke({\"question\":\"How many employees are there\"})\n",
    "print(response)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "795291ab-4031-49f2-8039-5d462b322b6c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'[(8,)]'"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "db.run(response)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "815f9581-03cf-4e36-ba20-ed75df3edfd0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
      "\u001b[32;1m\u001b[1;3mAction: sql_db_list_tables\n",
      "Action Input: \u001b[0m\n",
      "Observation: \u001b[38;5;200m\u001b[1;3mAlbum, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track\u001b[0m\n",
      "Thought:\u001b[32;1m\u001b[1;3m I should query the schema of the Invoice and Customer tables.\n",
      "Action: sql_db_schema\n",
      "Action Input: Invoice, Customer\u001b[0m\n",
      "Observation: \u001b[33;1m\u001b[1;3m\n",
      "CREATE TABLE `Customer` (\n",
      "\t`CustomerId` INTEGER(11) NOT NULL, \n",
      "\t`FirstName` VARCHAR(40) CHARACTER SET utf8 NOT NULL, \n",
      "\t`LastName` VARCHAR(20) CHARACTER SET utf8 NOT NULL, \n",
      "\t`Company` VARCHAR(80) CHARACTER SET utf8, \n",
      "\t`Address` VARCHAR(70) CHARACTER SET utf8, \n",
      "\t`City` VARCHAR(40) CHARACTER SET utf8, \n",
      "\t`State` VARCHAR(40) CHARACTER SET utf8, \n",
      "\t`Country` VARCHAR(40) CHARACTER SET utf8, \n",
      "\t`PostalCode` VARCHAR(10) CHARACTER SET utf8, \n",
      "\t`Phone` VARCHAR(24) CHARACTER SET utf8, \n",
      "\t`Fax` VARCHAR(24) CHARACTER SET utf8, \n",
      "\t`Email` VARCHAR(60) CHARACTER SET utf8 NOT NULL, \n",
      "\t`SupportRepId` INTEGER(11), \n",
      "\tPRIMARY KEY (`CustomerId`), \n",
      "\tCONSTRAINT `FK_CustomerSupportRepId` FOREIGN KEY(`SupportRepId`) REFERENCES `Employee` (`EmployeeId`)\n",
      ")ENGINE=InnoDB DEFAULT CHARSET=latin1\n",
      "\n",
      "/*\n",
      "3 rows from Customer table:\n",
      "CustomerId\tFirstName\tLastName\tCompany\tAddress\tCity\tState\tCountry\tPostalCode\tPhone\tFax\tEmail\tSupportRepId\n",
      "1\tLuís\tGonçalves\tEmbraer - Empresa Brasileira de Aeronáutica S.A.\tAv. Brigadeiro Faria Lima, 2170\tSão José dos Campos\tSP\tBrazil\t12227-000\t+55 (12) 3923-5555\t+55 (12) 3923-5566\tluisg@embraer.com.br\t3\n",
      "2\tLeonie\tKöhler\tNone\tTheodor-Heuss-Straße 34\tStuttgart\tNone\tGermany\t70174\t+49 0711 2842222\tNone\tleonekohler@surfeu.de\t5\n",
      "3\tFrançois\tTremblay\tNone\t1498 rue Bélanger\tMontréal\tQC\tCanada\tH2G 1A7\t+1 (514) 721-4711\tNone\tftremblay@gmail.com\t3\n",
      "*/\n",
      "\n",
      "\n",
      "CREATE TABLE `Invoice` (\n",
      "\t`InvoiceId` INTEGER(11) NOT NULL, \n",
      "\t`CustomerId` INTEGER(11) NOT NULL, \n",
      "\t`InvoiceDate` DATETIME NOT NULL, \n",
      "\t`BillingAddress` VARCHAR(70) CHARACTER SET utf8, \n",
      "\t`BillingCity` VARCHAR(40) CHARACTER SET utf8, \n",
      "\t`BillingState` VARCHAR(40) CHARACTER SET utf8, \n",
      "\t`BillingCountry` VARCHAR(40) CHARACTER SET utf8, \n",
      "\t`BillingPostalCode` VARCHAR(10) CHARACTER SET utf8, \n",
      "\t`Total` DECIMAL(10, 2) NOT NULL, \n",
      "\tPRIMARY KEY (`InvoiceId`), \n",
      "\tCONSTRAINT `FK_InvoiceCustomerId` FOREIGN KEY(`CustomerId`) REFERENCES `Customer` (`CustomerId`)\n",
      ")ENGINE=InnoDB DEFAULT CHARSET=latin1\n",
      "\n",
      "/*\n",
      "3 rows from Invoice table:\n",
      "InvoiceId\tCustomerId\tInvoiceDate\tBillingAddress\tBillingCity\tBillingState\tBillingCountry\tBillingPostalCode\tTotal\n",
      "1\t2\t2009-01-01 00:00:00\tTheodor-Heuss-Straße 34\tStuttgart\tNone\tGermany\t70174\t1.98\n",
      "2\t4\t2009-01-02 00:00:00\tUllevålsveien 14\tOslo\tNone\tNorway\t0171\t3.96\n",
      "3\t8\t2009-01-03 00:00:00\tGrétrystraat 63\tBrussels\tNone\tBelgium\t1000\t5.94\n",
      "*/\u001b[0m\n",
      "Thought:\u001b[32;1m\u001b[1;3m I should query the total sales per country.\n",
      "Action: sql_db_query\n",
      "Action Input: SELECT Country, SUM(Total) AS TotalSales FROM Invoice INNER JOIN Customer ON Invoice.CustomerId = Customer.CustomerId GROUP BY Country ORDER BY TotalSales DESC LIMIT 10;\u001b[0m\n",
      "Observation: \u001b[36;1m\u001b[1;3m[('USA', Decimal('523.06')), ('Canada', Decimal('303.96')), ('France', Decimal('195.10')), ('Brazil', Decimal('190.10')), ('Germany', Decimal('156.48')), ('United Kingdom', Decimal('112.86')), ('Czech Republic', Decimal('90.24')), ('Portugal', Decimal('77.24')), ('India', Decimal('75.26')), ('Chile', Decimal('46.62'))]\u001b[0m\n",
      "Thought:\u001b[32;1m\u001b[1;3m I now know the final answer.\n",
      "Final Answer: USA customers spent the most, with a total of $523.06.\u001b[0m\n",
      "\n",
      "\u001b[1m> Finished chain.\u001b[0m\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'USA customers spent the most, with a total of $523.06.'"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from langchain.agents import create_sql_agent\n",
    "from langchain.agents.agent_toolkits import SQLDatabaseToolkit\n",
    "# from langchain.agents import AgentExecutor\n",
    "from langchain.agents.agent_types import AgentType\n",
    "\n",
    "db = SQLDatabase.from_uri(\"mysql+pymysql://lang:lang@127.0.0.1:3306/Chinook?charset=utf8\")\n",
    "llm = OpenAI(temperature=0, verbose=True)\n",
    "\n",
    "agent_executor = create_sql_agent(\n",
    "    llm=OpenAI(temperature=0),\n",
    "    toolkit=SQLDatabaseToolkit(db=db, llm=OpenAI(temperature=0)),\n",
    "    verbose=True,\n",
    "    agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,\n",
    ")\n",
    "agent_executor.run(\n",
    "    \"List the total sales per country. Which country's customers spent the most?\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "c350d91d-1db2-458d-85ab-b1930df5b645",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\u001b[1m> Entering new SQLDatabaseChain chain...\u001b[0m\n",
      "共有多少员工?\n",
      "SQLQuery:\u001b[32;1m\u001b[1;3mSELECT COUNT(*) FROM Employee;\u001b[0m\n",
      "SQLResult: \u001b[33;1m\u001b[1;3m[(8,)]\u001b[0m\n",
      "Answer:\u001b[32;1m\u001b[1;3m共有8名员工。\u001b[0m\n",
      "\u001b[1m> Finished chain.\u001b[0m\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'共有8名员工。'"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "db_chain.run(\"共有多少员工?\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "cd7da32e-2cce-40d1-9b79-7fdfa542d273",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
      "\u001b[32;1m\u001b[1;3mAction: sql_db_list_tables\n",
      "Action Input: \u001b[0m\n",
      "Observation: \u001b[38;5;200m\u001b[1;3mAlbum, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track\u001b[0m\n",
      "Thought:\u001b[32;1m\u001b[1;3m I should query the schema of the Customer and Invoice tables.\n",
      "Action: sql_db_schema\n",
      "Action Input: Customer, Invoice\u001b[0m\n",
      "Observation: \u001b[33;1m\u001b[1;3m\n",
      "CREATE TABLE `Customer` (\n",
      "\t`CustomerId` INTEGER(11) NOT NULL, \n",
      "\t`FirstName` VARCHAR(40) CHARACTER SET utf8 NOT NULL, \n",
      "\t`LastName` VARCHAR(20) CHARACTER SET utf8 NOT NULL, \n",
      "\t`Company` VARCHAR(80) CHARACTER SET utf8, \n",
      "\t`Address` VARCHAR(70) CHARACTER SET utf8, \n",
      "\t`City` VARCHAR(40) CHARACTER SET utf8, \n",
      "\t`State` VARCHAR(40) CHARACTER SET utf8, \n",
      "\t`Country` VARCHAR(40) CHARACTER SET utf8, \n",
      "\t`PostalCode` VARCHAR(10) CHARACTER SET utf8, \n",
      "\t`Phone` VARCHAR(24) CHARACTER SET utf8, \n",
      "\t`Fax` VARCHAR(24) CHARACTER SET utf8, \n",
      "\t`Email` VARCHAR(60) CHARACTER SET utf8 NOT NULL, \n",
      "\t`SupportRepId` INTEGER(11), \n",
      "\tPRIMARY KEY (`CustomerId`), \n",
      "\tCONSTRAINT `FK_CustomerSupportRepId` FOREIGN KEY(`SupportRepId`) REFERENCES `Employee` (`EmployeeId`)\n",
      ")ENGINE=InnoDB DEFAULT CHARSET=latin1\n",
      "\n",
      "/*\n",
      "3 rows from Customer table:\n",
      "CustomerId\tFirstName\tLastName\tCompany\tAddress\tCity\tState\tCountry\tPostalCode\tPhone\tFax\tEmail\tSupportRepId\n",
      "1\tLuís\tGonçalves\tEmbraer - Empresa Brasileira de Aeronáutica S.A.\tAv. Brigadeiro Faria Lima, 2170\tSão José dos Campos\tSP\tBrazil\t12227-000\t+55 (12) 3923-5555\t+55 (12) 3923-5566\tluisg@embraer.com.br\t3\n",
      "2\tLeonie\tKöhler\tNone\tTheodor-Heuss-Straße 34\tStuttgart\tNone\tGermany\t70174\t+49 0711 2842222\tNone\tleonekohler@surfeu.de\t5\n",
      "3\tFrançois\tTremblay\tNone\t1498 rue Bélanger\tMontréal\tQC\tCanada\tH2G 1A7\t+1 (514) 721-4711\tNone\tftremblay@gmail.com\t3\n",
      "*/\n",
      "\n",
      "\n",
      "CREATE TABLE `Invoice` (\n",
      "\t`InvoiceId` INTEGER(11) NOT NULL, \n",
      "\t`CustomerId` INTEGER(11) NOT NULL, \n",
      "\t`InvoiceDate` DATETIME NOT NULL, \n",
      "\t`BillingAddress` VARCHAR(70) CHARACTER SET utf8, \n",
      "\t`BillingCity` VARCHAR(40) CHARACTER SET utf8, \n",
      "\t`BillingState` VARCHAR(40) CHARACTER SET utf8, \n",
      "\t`BillingCountry` VARCHAR(40) CHARACTER SET utf8, \n",
      "\t`BillingPostalCode` VARCHAR(10) CHARACTER SET utf8, \n",
      "\t`Total` DECIMAL(10, 2) NOT NULL, \n",
      "\tPRIMARY KEY (`InvoiceId`), \n",
      "\tCONSTRAINT `FK_InvoiceCustomerId` FOREIGN KEY(`CustomerId`) REFERENCES `Customer` (`CustomerId`)\n",
      ")ENGINE=InnoDB DEFAULT CHARSET=latin1\n",
      "\n",
      "/*\n",
      "3 rows from Invoice table:\n",
      "InvoiceId\tCustomerId\tInvoiceDate\tBillingAddress\tBillingCity\tBillingState\tBillingCountry\tBillingPostalCode\tTotal\n",
      "1\t2\t2009-01-01 00:00:00\tTheodor-Heuss-Straße 34\tStuttgart\tNone\tGermany\t70174\t1.98\n",
      "2\t4\t2009-01-02 00:00:00\tUllevålsveien 14\tOslo\tNone\tNorway\t0171\t3.96\n",
      "3\t8\t2009-01-03 00:00:00\tGrétrystraat 63\tBrussels\tNone\tBelgium\t1000\t5.94\n",
      "*/\u001b[0m\n",
      "Thought:\u001b[32;1m\u001b[1;3m I should query the total sales for each country and order the results by total sales.\n",
      "Action: sql_db_query\n",
      "Action Input: SELECT Country, SUM(Total) AS TotalSales FROM Customer c INNER JOIN Invoice i ON c.CustomerId = i.CustomerId GROUP BY Country ORDER BY TotalSales DESC LIMIT 10;\u001b[0m\n",
      "Observation: \u001b[36;1m\u001b[1;3m[('USA', Decimal('523.06')), ('Canada', Decimal('303.96')), ('France', Decimal('195.10')), ('Brazil', Decimal('190.10')), ('Germany', Decimal('156.48')), ('United Kingdom', Decimal('112.86')), ('Czech Republic', Decimal('90.24')), ('Portugal', Decimal('77.24')), ('India', Decimal('75.26')), ('Chile', Decimal('46.62'))]\u001b[0m\n",
      "Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n",
      "Final Answer: 美国的客户花费最多，总销售额为523.06美元。\u001b[0m\n",
      "\n",
      "\u001b[1m> Finished chain.\u001b[0m\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'美国的客户花费最多，总销售额为523.06美元。'"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "agent_executor.run(\n",
    "    \"列出每个国家/地区的总销售额。哪个国家的客户花费最多？\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c503071-32ae-431f-a3bb-64fdc1f29026",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
